#!/usr/bin/env python

import nest
import numpy as np
import pylab
import parameters as p

"""
    This network is based on the one developed for the R-STDP controller by @Clamesc (https://github.com/clamesc).
"""

class SpikingNeuralNetwork():
    """Reward-modulated STDP network mapping a DVS frame to two output spike counts.

    Topology built in ``__init__``:
    poisson generators -> parrot neurons (one-to-one) -> iaf_psc_alpha output
    neurons (all-to-all, synapse spec ``p.syn_params``) -> spike detectors
    (one-to-one). Output neuron 0 is treated as "left", neuron 1 as "right".
    """

    def __init__(self):
        """Reset the NEST kernel and build the network described on the class."""
        np.set_printoptions(precision=1)
        nest.set_verbosity('M_WARNING')
        nest.ResetKernel()
        nest.SetKernelStatus({"local_num_threads": 1, "resolution": p.time_resolution})

        # Spike generators turning DVS input into spike trains (one per pixel).
        self.spike_generators = nest.Create("poisson_generator", p.resolution[0] * p.resolution[1])

        # Parrot neurons re-emitting the spikes of the poisson generators.
        # Despite the "_left" suffix these feed BOTH output neurons; the name is
        # kept for compatibility. NOTE(review): the one-to-one connection below
        # requires p.observation_neurons_count == p.resolution[0] * p.resolution[1].
        self.neuron_pre_left = nest.Create("parrot_neuron", p.observation_neurons_count)

        # Output neurons. Indices 0 and 1 are used as left/right below, so
        # p.action_neurons_count is presumably 2.
        self.neuron_post = nest.Create("iaf_psc_alpha", p.action_neurons_count, params=[p.neuronparams])

        # One spike detector per output neuron.
        self.spike_detector = nest.Create("spike_detector", p.action_neurons_count, params=[p.sd_params])

        # The volume transmitter is required by stdp_dopamine_synapse to deliver
        # the dopamine signal to the plastic synapses.
        self.VolumeTransmitter = nest.Create("volume_transmitter")
        self.stdp_synapse_params = {
            "vt": self.VolumeTransmitter[0],
            'A_plus': 1.,
            'A_minus': 1.,
            'Wmin': 0.,
            'Wmax': 3000.,
            'tau_c': 1000.,
            'tau_n': 200.,
        }
        nest.SetDefaults("stdp_dopamine_synapse", self.stdp_synapse_params)

        nest.Connect(self.spike_generators, self.neuron_pre_left, "one_to_one")
        nest.Connect(self.neuron_pre_left, self.neuron_post, "all_to_all", p.syn_params)
        nest.Connect(self.neuron_post, self.spike_detector, "one_to_one")

        # Input connections of each output neuron. simulate() sets their "n"
        # (dopamine) status, so p.syn_params is presumably stdp_dopamine_synapse.
        self.conn_l = nest.GetConnections(target=[self.neuron_post[0]])
        self.conn_r = nest.GetConnections(target=[self.neuron_post[1]])

    def simulate(self, dvs_data, reward_conditional, reward_collision):
        """Run one step of ``p.sim_time`` for a single DVS frame.

        Parameters
        ----------
        dvs_data : array-like
            Input frame (presumably per-pixel event counts). It is flattened in
            row-major order and mapped onto the poisson generators. Each value
            is divided by ``p.max_spikes``, clipped to [0, 1] and multiplied by
            ``p.max_poisson_freq`` to give the generator rate.
        reward_conditional, reward_collision : float
            Reward terms combined into the dopamine signal. The left output's
            synapses receive ``reward_conditional * p.reward_factor -
            reward_collision`` and the right output's receive its negation.

        Returns
        -------
        tuple
            ``(n_l, n_r, weights_l, weights_r)``: spike counts of the left and
            right output neurons for this step (the detectors are reset after
            each call), and their input weights reshaped to ``p.resolution``.
        """
        # The two outputs are rewarded antisymmetrically.
        dopamine = reward_conditional * p.reward_factor - reward_collision
        nest.SetStatus(self.conn_l, {"n": dopamine})
        nest.SetStatus(self.conn_r, {"n": -dopamine})

        # Restart the generators' activity window at the current kernel time.
        # "stop" is relative to "origin", so they stay active for p.sim_time.
        now = nest.GetKernelStatus("time")
        nest.SetStatus(self.spike_generators, {"origin": now, "stop": p.sim_time})

        # Cast to float before dividing. Integer inputs would otherwise be
        # floor-divided under Python 2, collapsing every rate to 0 or the maximum.
        pixels = np.asarray(dvs_data, dtype=float).ravel()
        rates = np.clip(pixels / p.max_spikes, 0, 1) * p.max_poisson_freq
        if pixels.size:
            # One SetStatus call for all generators instead of one per pixel.
            nest.SetStatus(self.spike_generators[:pixels.size],
                           [{"rate": rate} for rate in rates.tolist()])

        nest.Simulate(p.sim_time)

        # Query both detectors once, then reset them so the next call only
        # counts the spikes of its own step.
        counts = nest.GetStatus(self.spike_detector, keys="n_events")
        n_l, n_r = counts[0], counts[1]
        nest.SetStatus(self.spike_detector, {"n_events": 0})

        weights_l = self._read_weights(self.conn_l)
        weights_r = self._read_weights(self.conn_r)

        return n_l, n_r, weights_l, weights_r

    def set_weights(self, weights_l, weights_r):
        """Overwrite the input weights of the left and right output neurons.

        Each argument may be any array-like. It is flattened in row-major
        order and must hold one value per connection in ``conn_l`` /
        ``conn_r`` (the layout returned by ``simulate``).
        """
        nest.SetStatus(self.conn_l, self._weight_dicts(weights_l))
        nest.SetStatus(self.conn_r, self._weight_dicts(weights_r))

    def _read_weights(self, conns):
        """Return the weights of *conns* as an array reshaped to ``p.resolution``."""
        # NOTE(review): assumes GetConnections orders connections by source
        # (i.e. by pixel); verify if the pixel layout of the weights matters.
        return np.array(nest.GetStatus(conns, keys="weight")).reshape(p.resolution)

    def _weight_dicts(self, weights):
        """Return one ``{'weight': w}`` status dict per element of *weights*."""
        return [{'weight': w} for w in np.asarray(weights, dtype=float).ravel().tolist()]